{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `pyro.contrib.funsor`, a new backend for Pyro - New primitives (Part 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "In this tutorial we'll cover the basics of `pyro.contrib.funsor`, a new backend for the Pyro probabilistic programming system that is intended to replace the current internals of Pyro and significantly expand its capabilities as both a modelling tool and an inference research platform.\n",
    "\n",
    "This tutorial is aimed at readers interested in developing custom inference algorithms and understanding Pyro's current and future internals. As such, the material here assumes some familiarity with the [generic Pyro API package](https://pyro.ai/api/) `pyroapi` and with Funsor. Additional documentation for Funsor can be found on [the Pyro website](https://funsor.pyro.ai/en/stable/), on [GitHub](https://github.com/pyro-ppl/funsor), and in the research paper [\"Functional Tensors for Probabilistic Programming.\"](https://arxiv.org/abs/1910.10775) Those who are less interested in such details should find that they can already use the general-purpose algorithms in `contrib.funsor` with their existing Pyro models via `pyroapi`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reinterpreting existing Pyro models with `pyroapi`\n",
    "\n",
    "The new backend uses the `pyroapi` package to integrate with existing Pyro code.\n",
    "\n",
    "First, we import some dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "import torch\n",
    "import funsor\n",
    "from pyro import set_rng_seed as pyro_set_rng_seed\n",
    "\n",
    "funsor.set_backend(\"torch\")\n",
    "torch.set_default_dtype(torch.float32)\n",
    "pyro_set_rng_seed(101)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Importing `pyro.contrib.funsor` registers the `\"contrib.funsor\"` backend with `pyroapi`, which can now be passed as an argument to the `pyroapi.pyro_backend` context manager."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyro.contrib.funsor\n",
    "import pyroapi\n",
    "from pyroapi import handlers, infer, ops, optim, pyro\n",
    "from pyroapi import distributions as dist\n",
    "\n",
    "# this is already done in pyro.contrib.funsor, but we repeat it here\n",
    "pyroapi.register_backend(\"contrib.funsor\", dict(\n",
    "    distributions=\"pyro.distributions\",\n",
    "    handlers=\"pyro.contrib.funsor.handlers\",\n",
    "    infer=\"pyro.contrib.funsor.infer\",\n",
    "    ops=\"torch\",\n",
    "    optim=\"pyro.optim\",\n",
    "    pyro=\"pyro.contrib.funsor\",\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we're off! From here on, any `pyro.(...)` statement should be understood as dispatching to the new backend."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Two new primitives: `to_funsor` and `to_data`\n",
    "\n",
    "The first and most important new concept in `pyro.contrib.funsor` is the new pair of primitives `pyro.to_funsor` and `pyro.to_data`.\n",
    "\n",
    "These are *effectful* versions of `funsor.to_funsor` and `funsor.to_data`, i.e. versions whose behavior can be intercepted, controlled, or used to trigger side effects by Pyro's library of algebraic effect handlers. Let's briefly review these two underlying functions before diving into the effectful versions in `pyro.contrib.funsor`.\n",
    "\n",
    "As one might expect from the name, `to_funsor` takes as inputs objects that are not `funsor.Funsor`s and attempts to convert them into Funsor terms. For example, calling `funsor.to_funsor` on a Python number converts it to a `funsor.terms.Number` object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 <class 'funsor.terms.Number'>\n",
      "tensor(2.) <class 'funsor.tensor.Tensor'>\n"
     ]
    }
   ],
   "source": [
    "funsor_one = funsor.to_funsor(float(1))\n",
    "print(funsor_one, type(funsor_one))\n",
    "\n",
    "funsor_two = funsor.to_funsor(torch.tensor(2.))\n",
    "print(funsor_two, type(funsor_two))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly ,calling `funsor.to_data` on an atomic `funsor.Funsor` converts it to a regular Python object like a `float` or a `torch.Tensor`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 <class 'float'>\n",
      "tensor(2.) <class 'torch.Tensor'>\n"
     ]
    }
   ],
   "source": [
    "data_one = funsor.to_data(funsor.terms.Number(float(1), 'real'))\n",
    "print(data_one, type(data_one))\n",
    "\n",
    "data_two = funsor.to_data(funsor.Tensor(torch.tensor(2.), OrderedDict(), 'real'))\n",
    "print(data_two, type(data_two))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In many cases it is necessary to provide an output type to uniquely convert a piece of data to a `funsor.Funsor`. This also means that, strictly speaking, `funsor.to_funsor` and `funsor.to_data` are not inverses. For example, `funsor.to_funsor` will automatically convert Python strings to `funsor.Variable`s, but only when given an output `funsor.domains.Domain`, which serves as the type of the variable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x OrderedDict([('x', Reals[2])]) Reals[2]\n"
     ]
    }
   ],
   "source": [
    "var_x = funsor.to_funsor(\"x\", output=funsor.Reals[2])\n",
    "print(var_x, var_x.inputs, var_x.output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, it is often impossible to convert objects to and from Funsor expressions uniquely without additional type information about inputs, as in the following example of a `torch.Tensor`, which could be converted to a `funsor.Tensor` in several ways.\n",
    "\n",
    "To resolve this ambiguity, we need to provide `to_funsor` and `to_data` with type information that describes how to convert positional dimensions to and from unordered named Funsor dimensions. This information comes in the form of dictionaries mapping batch dimensions to dimension names or vice versa.\n",
    "\n",
    "A key property of these mappings is that use the convention that dimension indices refer to *batch dimensions*, or dimensions not included in the *output shape*, which is treated as referring to the rightmost portion of the underlying PyTorch tensor shape, as illustrated in the example below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ambiguous tensor: shape = torch.Size([3, 1, 2])\n",
      "Case 1: inputs = OrderedDict(), output = Reals[3,1,2]\n",
      "Case 2: inputs = OrderedDict([('a', Bint[3, ])]), output = Reals[1,2]\n",
      "Case 3: inputs = OrderedDict([('a', Bint[3, ])]), output = Reals[2]\n",
      "Case 4: inputs = OrderedDict([('a', Bint[3, ]), ('c', Bint[2, ])]), output = Real\n"
     ]
    }
   ],
   "source": [
    "ambiguous_tensor = torch.zeros((3, 1, 2))\n",
    "print(\"Ambiguous tensor: shape = {}\".format(ambiguous_tensor.shape))\n",
    "\n",
    "# case 1: treat all dimensions as output/event dimensions\n",
    "funsor1 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[3, 1, 2])\n",
    "print(\"Case 1: inputs = {}, output = {}\".format(funsor1.inputs, funsor1.output))\n",
    "\n",
    "# case 2: treat the leftmost dimension as a batch dimension\n",
    "# note that dimension -1 in dim_to_name here refers to the rightmost *batch dimension*,\n",
    "# i.e. dimension -3 of ambiguous_tensor, the rightmost dimension not included in the output shape.\n",
    "funsor2 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[1, 2], dim_to_name={-1: \"a\"})\n",
    "print(\"Case 2: inputs = {}, output = {}\".format(funsor2.inputs, funsor2.output))\n",
    "\n",
    "# case 3: treat the leftmost 2 dimensions as batch dimensions; empty batch dimensions are ignored\n",
    "# note that dimensions -1 and -2 in dim_to_name here refer to the rightmost *batch dimensions*,\n",
    "# i.e. dimensions -2 and -3 of ambiguous_tensor, the rightmost dimensions not included in the output shape.\n",
    "funsor3 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[2], dim_to_name={-1: \"b\", -2: \"a\"})\n",
    "print(\"Case 3: inputs = {}, output = {}\".format(funsor3.inputs, funsor3.output))\n",
    "\n",
    "# case 4: treat all dimensions as batch dimensions; empty batch dimensions are ignored\n",
    "# note that dimensions -1, -2 and -3 in dim_to_name here refer to the rightmost *batch dimensions*,\n",
    "# i.e. dimensions -1, -2 and -3 of ambiguous_tensor, the rightmost dimensions not included in the output shape.\n",
    "funsor4 = funsor.to_funsor(ambiguous_tensor, output=funsor.Real, dim_to_name={-1: \"c\", -2: \"b\", -3: \"a\"})\n",
    "print(\"Case 4: inputs = {}, output = {}\".format(funsor4.inputs, funsor4.output))"
   ]
  },
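  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch-dimension convention used above can also be spelled out in plain Python. The cell below is a minimal sketch, not Funsor's implementation: `sketch_inputs` is a hypothetical helper introduced only for illustration. It strips the output shape from the right of the tensor shape, treats the remainder as the batch shape, skips size-1 batch dimensions, and looks up the remaining ones in `dim_to_name`. For each name it reports the size and the index of the corresponding dimension in the full tensor shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "\n",
    "def sketch_inputs(tensor_shape, output_shape, dim_to_name):\n",
    "    # hypothetical helper for illustration only; not part of Funsor\n",
    "    n_out = len(output_shape)\n",
    "    # the output shape is the rightmost part of the tensor shape\n",
    "    batch_shape = tuple(tensor_shape[:len(tensor_shape) - n_out])\n",
    "    assert tuple(tensor_shape[len(batch_shape):]) == tuple(output_shape)\n",
    "    inputs = OrderedDict()\n",
    "    # walk the batch dimensions left to right using negative indices\n",
    "    for batch_dim in range(-len(batch_shape), 0):\n",
    "        size = batch_shape[batch_dim]\n",
    "        if size == 1:\n",
    "            continue  # empty batch dimensions are ignored\n",
    "        tensor_dim = batch_dim - n_out  # index into the full tensor shape\n",
    "        inputs[dim_to_name[batch_dim]] = (size, tensor_dim)\n",
    "    return inputs\n",
    "\n",
    "\n",
    "# case 2 above: batch dimension -1 is dimension -3 of the tensor\n",
    "print(list(sketch_inputs((3, 1, 2), (1, 2), {-1: \"a\"}).items()))\n",
    "# case 4 above: scalar output, so batch and tensor indices coincide\n",
    "print(list(sketch_inputs((3, 1, 2), (), {-1: \"c\", -2: \"b\", -3: \"a\"}).items()))"
   ]
  },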
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similar ambiguity exists for `to_data`: the `inputs` of a `funsor.Funsor` are ordered arbitrarily, and empty dimensions in the data are squeezed away, so a mapping from names to batch dimensions must be provided to ensure unique conversion:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ambiguous funsor: inputs = OrderedDict([('a', Bint[3, ]), ('b', Bint[2, ])]), shape = Real\n",
      "Case 1: shape = torch.Size([3, 2])\n",
      "Case 2: shape = torch.Size([3, 1, 2])\n",
      "Case 3: shape = torch.Size([2, 3])\n"
     ]
    }
   ],
   "source": [
    "ambiguous_funsor = funsor.Tensor(torch.zeros((3, 2)), OrderedDict(a=funsor.Bint[3], b=funsor.Bint[2]), 'real')\n",
    "print(\"Ambiguous funsor: inputs = {}, shape = {}\".format(ambiguous_funsor.inputs, ambiguous_funsor.output))\n",
    "\n",
    "# case 1: the simplest version\n",
    "tensor1 = funsor.to_data(ambiguous_funsor, name_to_dim={\"a\": -2, \"b\": -1})\n",
    "print(\"Case 1: shape = {}\".format(tensor1.shape))\n",
    "\n",
    "# case 2: an empty dimension between a and b\n",
    "tensor2 = funsor.to_data(ambiguous_funsor, name_to_dim={\"a\": -3, \"b\": -1})\n",
    "print(\"Case 2: shape = {}\".format(tensor2.shape))\n",
    "\n",
    "# case 3: permuting the input dimensions\n",
    "tensor3 = funsor.to_data(ambiguous_funsor, name_to_dim={\"a\": -1, \"b\": -2})\n",
    "print(\"Case 3: shape = {}\".format(tensor3.shape))"
   ]
  },
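  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes printed above can be reproduced with a few lines of plain Python. This is only a sketch under stated assumptions: `sketch_to_data_shape` is a hypothetical helper, it assumes a scalar output, and it assumes the result needs just enough dimensions to hold the names that actually appear in `inputs`. It is not how `funsor.to_data` is implemented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sketch_to_data_shape(inputs, name_to_dim):\n",
    "    # hypothetical helper; `inputs` maps names to sizes, output assumed scalar\n",
    "    if not inputs:\n",
    "        return ()\n",
    "    dims = [name_to_dim[name] for name in inputs]\n",
    "    # enough dimensions to reach the leftmost named dimension, filled with 1s\n",
    "    shape = [1] * max(-dim for dim in dims)\n",
    "    for name, size in inputs.items():\n",
    "        shape[name_to_dim[name]] = size  # place each size at its negative index\n",
    "    return tuple(shape)\n",
    "\n",
    "\n",
    "sizes = {\"a\": 3, \"b\": 2}\n",
    "print(sketch_to_data_shape(sizes, {\"a\": -2, \"b\": -1}))  # (3, 2)\n",
    "print(sketch_to_data_shape(sizes, {\"a\": -3, \"b\": -1}))  # (3, 1, 2)\n",
    "print(sketch_to_data_shape(sizes, {\"a\": -1, \"b\": -2}))  # (2, 3)"
   ]
  },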
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Maintaining and updating this information efficiently becomes tedious and error-prone as the number of conversions increases. Fortunately, it can be automated away completely. Consider the following example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('x', -1)]) OrderedDict([('x', Bint[2, ])]) torch.Size([2])\n",
      "OrderedDict([('x', -1), ('y', -2)]) OrderedDict([('y', Bint[3, ]), ('x', Bint[2, ])]) torch.Size([3, 2])\n",
      "OrderedDict([('x', -1), ('y', -2), ('z', -3)]) OrderedDict([('z', Bint[2, ]), ('y', Bint[3, ])]) torch.Size([2, 3, 1])\n"
     ]
    }
   ],
   "source": [
    "name_to_dim = OrderedDict()\n",
    "\n",
    "funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.Bint[2]), 'real')\n",
    "name_to_dim.update({\"x\": -1})\n",
    "tensor_x = funsor.to_data(funsor_x, name_to_dim=name_to_dim)\n",
    "print(name_to_dim, funsor_x.inputs, tensor_x.shape)\n",
    "\n",
    "funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.Bint[3], x=funsor.Bint[2]), 'real')\n",
    "name_to_dim.update({\"y\": -2})\n",
    "tensor_y = funsor.to_data(funsor_y, name_to_dim=name_to_dim)\n",
    "print(name_to_dim, funsor_y.inputs, tensor_y.shape)\n",
    "\n",
    "funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.Bint[2], y=funsor.Bint[3]), 'real')\n",
    "name_to_dim.update({\"z\": -3})\n",
    "tensor_z = funsor.to_data(funsor_z, name_to_dim=name_to_dim)\n",
    "print(name_to_dim, funsor_z.inputs, tensor_z.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is exactly the functionality provided by `pyro.to_funsor` and `pyro.to_data`, as we can see by using them in the previous example and removing the manual updates. We must also wrap the function in a `handlers.named` effect handler to ensure that the dimension dictionaries do not persist beyond the function body."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('x', Bint[2, ])]) torch.Size([2, 1, 1, 1, 1])\n",
      "OrderedDict([('y', Bint[3, ]), ('x', Bint[2, ])]) torch.Size([3, 2, 1, 1, 1, 1])\n",
      "OrderedDict([('z', Bint[2, ]), ('y', Bint[3, ])]) torch.Size([2, 3, 1, 1, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"), handlers.named():\n",
    "    funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.Bint[2]), 'real')\n",
    "    tensor_x = pyro.to_data(funsor_x)\n",
    "    print(funsor_x.inputs, tensor_x.shape)\n",
    "\n",
    "    funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.Bint[3], x=funsor.Bint[2]), 'real')\n",
    "    tensor_y = pyro.to_data(funsor_y)\n",
    "    print(funsor_y.inputs, tensor_y.shape)\n",
    "\n",
    "    funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.Bint[2], y=funsor.Bint[3]), 'real')\n",
    "    tensor_z = pyro.to_data(funsor_z)\n",
    "    print(funsor_z.inputs, tensor_z.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Critically, `pyro.to_funsor` and `pyro.to_data` use and update the same bidirectional mapping between names and dimensions, allowing them to be combined intuitively. A typical usage pattern, and one that `pyro.contrib.funsor` uses heavily in its inference algorithm implementations, is to create a `funsor.Funsor` term directly with a new named dimension and call `pyro.to_data` on it, perform some PyTorch computations, and call `pyro.to_funsor` on the result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'funsor.tensor.Tensor'> OrderedDict([('batch', Bint[3, ])]) Real\n",
      "<class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[4, ])]) Real\n",
      "<class 'pyro.distributions.torch.Bernoulli'> torch.Size([3, 1, 1, 1, 1])\n",
      "<class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[4, ]), ('batch', Bint[3, ])]) Real\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"), handlers.named():\n",
    "    \n",
    "    probs = funsor.Tensor(torch.tensor([0.5, 0.4, 0.7]), OrderedDict(batch=funsor.Bint[3]))\n",
    "    print(type(probs), probs.inputs, probs.output)\n",
    "    \n",
    "    x = funsor.Tensor(torch.tensor([0., 1., 0., 1.]), OrderedDict(x=funsor.Bint[4]))\n",
    "    print(type(x), x.inputs, x.output)\n",
    "    \n",
    "    dx = dist.Bernoulli(pyro.to_data(probs))\n",
    "    print(type(dx), dx.shape())\n",
    "    \n",
    "    px = pyro.to_funsor(dx.log_prob(pyro.to_data(x)), output=funsor.Real)\n",
    "    print(type(px), px.inputs, px.output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pyro.to_funsor` and `pyro.to_data` treat the keys in their name-to-dim mappings as references to the input's batch shape, but treats the values as references to the globally consistent name-dim mapping. This may be useful for complicated computations that involve a mixture of PyTorch and Funsor operations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x:  <class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[2, ])]) Real\n",
      "px:  <class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[2, ]), ('y', Bint[3, ])]) Real\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"), handlers.named():\n",
    "    \n",
    "    x = pyro.to_funsor(torch.tensor([0., 1.]), funsor.Real, dim_to_name={-1: \"x\"})\n",
    "    print(\"x: \", type(x), x.inputs, x.output)\n",
    "    \n",
    "    px = pyro.to_funsor(torch.ones(2, 3), funsor.Real, dim_to_name={-2: \"x\", -1: \"y\"})\n",
    "    print(\"px: \", type(px), px.inputs, px.output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dealing with large numbers of variables: (re-)introducing `pyro.markov`\n",
    "\n",
    "So far, so good. However, what if the number of different named dimensions continues to increase? We face two problems: first, reusing the fixed number of available positional dimensions (25 in PyTorch), and second, computing shape information with time complexity that is independent of the number of variables.\n",
    "\n",
    "A fully general automated solution to this problem would require deeper integration with Python or PyTorch. Instead, as an intermediate solution, we introduce the second key concept in `pyro.contrib.funsor`: the `pyro.markov` annotation, a way to indicate the shelf life of certain variables. `pyro.markov` is already part of Pyro (see enumeration tutorial) but the implementation in `pyro.contrib.funsor` is fresh.\n",
    "\n",
    "The primary constraint on the design of `pyro.markov` is backwards compatibility: in order for `pyro.contrib.funsor` to be compatible with the large range of existing Pyro models, the new implementation had to match the shape semantics of Pyro's existing enumeration machinery as closely as possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of x[0]:  torch.Size([2, 1, 1, 1, 1])\n",
      "Shape of x[1]:  torch.Size([2, 1, 1, 1, 1, 1])\n",
      "Shape of x[2]:  torch.Size([2, 1, 1, 1, 1])\n",
      "Shape of x[3]:  torch.Size([2, 1, 1, 1, 1, 1])\n",
      "Shape of x[4]:  torch.Size([2, 1, 1, 1, 1])\n",
      "Shape of x[5]:  torch.Size([2, 1, 1, 1, 1, 1])\n",
      "Shape of x[6]:  torch.Size([2, 1, 1, 1, 1])\n",
      "Shape of x[7]:  torch.Size([2, 1, 1, 1, 1, 1])\n",
      "Shape of x[8]:  torch.Size([2, 1, 1, 1, 1])\n",
      "Shape of x[9]:  torch.Size([2, 1, 1, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"), handlers.named():\n",
    "    for i in pyro.markov(range(10)):\n",
    "        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({\"x{}\".format(i): funsor.Bint[2]})))\n",
    "        print(\"Shape of x[{}]: \".format(str(i)), x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pyro.markov` is a versatile piece of syntax that can be used as a context manager, a decorator, or an iterator. It is important to understand that `pyro.markov`'s only functionality at present is tracking variable usage, not directly indicating conditional independence properties to inference algorithms, and as such it is only necessary to add enough annotations to ensure that tensors have correct shapes, rather than attempting to manually encode as much dependency information as possible."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pyro.markov` takes an additional argument `history` that determines the number of previous `pyro.markov` contexts to take into account when building the mapping between names and dimensions at a given `pyro.to_funsor`/`pyro.to_data` call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of x[0]:  torch.Size([2, 1, 1, 1, 1])\n",
      "Shape of x[1]:  torch.Size([2, 1, 1, 1, 1, 1])\n",
      "Shape of x[2]:  torch.Size([2, 1, 1, 1, 1, 1, 1])\n",
      "Shape of x[3]:  torch.Size([2, 1, 1, 1, 1])\n",
      "Shape of x[4]:  torch.Size([2, 1, 1, 1, 1, 1])\n",
      "Shape of x[5]:  torch.Size([2, 1, 1, 1, 1, 1, 1])\n",
      "Shape of x[6]:  torch.Size([2, 1, 1, 1, 1])\n",
      "Shape of x[7]:  torch.Size([2, 1, 1, 1, 1, 1])\n",
      "Shape of x[8]:  torch.Size([2, 1, 1, 1, 1, 1, 1])\n",
      "Shape of x[9]:  torch.Size([2, 1, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "with pyroapi.pyro_backend(\"contrib.funsor\"), handlers.named():\n",
    "    for i in pyro.markov(range(10), history=2):\n",
    "        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({\"x{}\".format(i): funsor.Bint[2]})))\n",
    "        print(\"Shape of x[{}]: \".format(str(i)), x.shape)"
   ]
  },
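  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes printed in the two `pyro.markov` examples follow a simple periodic pattern, which the toy model below reproduces. This is a sketch under stated assumptions, not the real allocator: `sketch_markov_shapes` is a hypothetical helper, it assumes exactly one fresh size-2 variable per step, and it assumes the first available dimension is -5 (the 4 rightmost dimensions being reserved by default). In this model, step `i` reuses slot `i % (history + 1)`, counting leftward from the first available dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sketch_markov_shapes(num_steps, history=1, first_available_dim=-5, size=2):\n",
    "    # toy model for illustration only; the real allocator tracks names in frames\n",
    "    shapes = []\n",
    "    for i in range(num_steps):\n",
    "        # dimensions are recycled with period history + 1\n",
    "        dim = first_available_dim - (i % (history + 1))\n",
    "        # the variable sits at position `dim`; everything to its right has size 1\n",
    "        shapes.append((size,) + (1,) * (-dim - 1))\n",
    "    return shapes\n",
    "\n",
    "\n",
    "for history in (1, 2):\n",
    "    lengths = [len(shape) for shape in sketch_markov_shapes(10, history=history)]\n",
    "    print(\"history = {}: number of dimensions per step = {}\".format(history, lengths))"
   ]
  },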
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use cases beyond enumeration: global and visible dimensions\n",
    "\n",
    "### Global dimensions\n",
    "\n",
    "It is sometimes useful to have dimensions and variables ignore the `pyro.markov` structure of a program and remain active in arbitrarily deeply nested `markov` and `named` contexts. For example, suppose we wanted to draw a batch of samples from a Pyro model's joint distribution. To accomplish this we indicate to `pyro.to_data` that a dimension should be treated as \"global\" (`DimType.GLOBAL`) via the `dim_type` keyword argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New global dimension:  OrderedDict([('n', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimType\n",
    "\n",
    "with pyroapi.pyro_backend(\"contrib.funsor\"), handlers.named():\n",
    "    funsor_particle_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))\n",
    "    tensor_particle_ids = pyro.to_data(funsor_particle_ids, dim_type=DimType.GLOBAL)\n",
    "    print(\"New global dimension: \", funsor_particle_ids.inputs, tensor_particle_ids.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pyro.markov` does the hard work of automatically managing local dimensions, but because global dimensions ignore this structure, they must be deallocated manually or they will persist until the last active effect handler exits, just as global variables in Python persist until a program execution finishes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New global dimension:  OrderedDict([('plate1', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])\n",
      "Another new global dimension:  OrderedDict([('plate2', Bint[9, ])]) torch.Size([9, 1, 1, 1, 1, 1])\n",
      "A third new global dimension after recycling:  OrderedDict([('plate3', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimType\n",
    "\n",
    "with pyroapi.pyro_backend(\"contrib.funsor\"), handlers.named():\n",
    "    \n",
    "    funsor_plate1_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate1=funsor.Bint[10]))\n",
    "    tensor_plate1_ids = pyro.to_data(funsor_plate1_ids, dim_type=DimType.GLOBAL)\n",
    "    print(\"New global dimension: \", funsor_plate1_ids.inputs, tensor_plate1_ids.shape)\n",
    "    \n",
    "    funsor_plate2_ids = funsor.Tensor(torch.arange(9), OrderedDict(plate2=funsor.Bint[9]))\n",
    "    tensor_plate2_ids = pyro.to_data(funsor_plate2_ids, dim_type=DimType.GLOBAL)\n",
    "    print(\"Another new global dimension: \", funsor_plate2_ids.inputs, tensor_plate2_ids.shape)\n",
    "    \n",
    "    del _DIM_STACK.global_frame[\"plate1\"]\n",
    "    \n",
    "    funsor_plate3_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate3=funsor.Bint[10]))\n",
    "    tensor_plate3_ids = pyro.to_data(funsor_plate1_ids, dim_type=DimType.GLOBAL)\n",
    "    print(\"A third new global dimension after recycling: \", funsor_plate3_ids.inputs, tensor_plate3_ids.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Performing this deallocation directly is often unnecessary, and we include this interaction primarily to illuminate the internals of `pyro.contrib.funsor`. Instead, effect handlers that introduce global dimensions, like `pyro.plate`, may inherit from the `GlobalNamedMessenger` effect handler which deallocates global dimensions generically upon entry and exit. We will see an example of this in the next tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visible dimensions\n",
    "\n",
    "We might also wish to preserve the meaning of the shape of a tensor of data. For this we indicate to `pyro.to_data` that a dimension should be treated as not merely global but \"visible\" (`DimTypes.VISIBLE`). By default, the 4 rightmost batch dimensions are reserved as \"visible\" dimensions, but this can be changed by setting the `first_available_dim` attribute of the global state object `_DIM_STACK`.\n",
    "\n",
    "Users who have come across `pyro.infer.TraceEnum_ELBO`'s `max_plate_nesting` argument are already familiar with this distinction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor with new local dimension:  OrderedDict([('k', Bint[9, ])]) torch.Size([9, 1])\n",
      "Tensor with new global dimension:  OrderedDict([('n', Bint[10, ])]) torch.Size([10, 1, 1])\n",
      "Tensor with new visible dimension:  OrderedDict([('m', Bint[11, ])]) torch.Size([11])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "-5"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prev_first_available_dim = _DIM_STACK.set_first_available_dim(-2)\n",
    "\n",
    "with pyroapi.pyro_backend(\"contrib.funsor\"), handlers.named():\n",
    "    \n",
    "    funsor_local_ids = funsor.Tensor(torch.arange(9), OrderedDict(k=funsor.Bint[9]))\n",
    "    tensor_local_ids = pyro.to_data(funsor_local_ids, dim_type=DimType.LOCAL)\n",
    "    print(\"Tensor with new local dimension: \", funsor_local_ids.inputs, tensor_local_ids.shape)\n",
    "    \n",
    "    funsor_global_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))\n",
    "    tensor_global_ids = pyro.to_data(funsor_global_ids, dim_type=DimType.GLOBAL)\n",
    "    print(\"Tensor with new global dimension: \", funsor_global_ids.inputs, tensor_global_ids.shape)\n",
    "    \n",
    "    funsor_data_ids = funsor.Tensor(torch.arange(11), OrderedDict(m=funsor.Bint[11]))\n",
    "    tensor_data_ids = pyro.to_data(funsor_data_ids, dim_type=DimType.VISIBLE)\n",
    "    print(\"Tensor with new visible dimension: \", funsor_data_ids.inputs, tensor_data_ids.shape)\n",
    "    \n",
    "# we also need to reset the first_available_dim after we're done\n",
    "_DIM_STACK.set_first_available_dim(prev_first_available_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visible dimensions are also global and must therefore be deallocated manually or they will persist until the last effect handler exits, as in the previous example. You may be thinking by now that Funsor's dimension names behave sort of like Python variables, with scopes and persistent meanings across expressions; indeed, this observation is the key insight behind the design of Funsor.\n",
    "\n",
    "Fortunately, interacting directly with the dimension allocator is almost always unnecessary, and as in the previous section we include it here only to illuminate the inner workings of `pyro.contrib.funsor`; rather, effect handlers like `pyro.handlers.enum` that may introduce non-visible dimensions that could conflict with visible dimensions should inherit from the base `pyro.contrib.funsor.handlers.named_messenger.NamedMessenger` effect handler. \n",
    "\n",
    "However, building a bit of intuition for the inner workings of the dimension allocator will make it easier to use the new primitives in `contrib.funsor` to build powerful new custom inference engines. We will see an example of one such inference engine in the next tutorial."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
